{
 "cells": [
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-24T09:14:03.893310Z",
     "start_time": "2024-11-24T09:14:03.890548Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Endpoints and API keys are handed to later cells through environment variables.\n",
    "# NOTE: %env echoes every assignment into this cell's output, so clear the output\n",
    "# before sharing the notebook once the placeholders below hold real keys.\n",
    "\n",
    "# DeepSeek endpoint/key: used to build the contextual embeddings.\n",
    "%env DEEPSEEK_BASE_URL=https://api.deepseek.com/v1\n",
    "%env DEEPSEEK_API_KEY=替换为Deepseek API Key，用于构建Contextual Embedding\n",
    "\n",
    "# Alibaba Cloud (DashScope) endpoint/key: used for evaluation.\n",
    "%env LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1\n",
    "%env LLM_API_KEY=替换为阿里云API Key，用于评估"
   ],
   "id": "16cef7f81941619c",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: DEEPSEEK_BASE_URL=https://api.deepseek.com/v1\n",
      "env: DEEPSEEK_API_KEY=替换为Deepseek API Key，用于构建Contextual Embedding\n",
      "env: LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1\n",
      "env: LLM_API_KEY=替换为阿里云API Key，用于评估\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "af375836-b870-458b-87d1-4e00565977eb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:33:09.806352Z",
     "iopub.status.busy": "2024-11-24T08:33:09.805645Z",
     "iopub.status.idle": "2024-11-24T08:33:09.816463Z",
     "shell.execute_reply": "2024-11-24T08:33:09.814764Z",
     "shell.execute_reply.started": "2024-11-24T08:33:09.806282Z"
    },
    "papermill": {
     "duration": 0.115454,
     "end_time": "2024-11-23T14:29:00.919641",
     "exception": false,
     "start_time": "2024-11-23T14:29:00.804187",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%capture --no-stderr\n",
    "# shutil is part of the Python standard library, not a pip requirement, so it is not listed here.\n",
    "# %pip (instead of !pip) installs into the environment of the running kernel.\n",
    "%pip install -U langchain langchain-community langchain-openai langchain-cohere pypdf sentence_transformers chromadb openpyxl FlagEmbedding cohere"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1e2c72b8-ee12-4130-af88-699998aa230c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:33:09.818726Z",
     "iopub.status.busy": "2024-11-24T08:33:09.818235Z",
     "iopub.status.idle": "2024-11-24T08:33:10.058036Z",
     "shell.execute_reply": "2024-11-24T08:33:10.057585Z",
     "shell.execute_reply.started": "2024-11-24T08:33:09.818678Z"
    },
    "papermill": {
     "duration": 0.319981,
     "end_time": "2024-11-23T14:29:01.380771",
     "exception": false,
     "start_time": "2024-11-23T14:29:01.060790",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "841d2b02-ad06-40d2-b11f-c7adccec6ca2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:33:10.058991Z",
     "iopub.status.busy": "2024-11-24T08:33:10.058590Z",
     "iopub.status.idle": "2024-11-24T08:33:10.184889Z",
     "shell.execute_reply": "2024-11-24T08:33:10.182750Z",
     "shell.execute_reply.started": "2024-11-24T08:33:10.058954Z"
    },
    "papermill": {
     "duration": 0.121409,
     "end_time": "2024-11-23T14:29:01.638126",
     "exception": false,
     "start_time": "2024-11-23T14:29:01.516717",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Name of this experiment; also used as the name of its output folder.\n",
    "expr_version = 'retrieval_v13_contextual_embeddings_deepseek'\n",
    "\n",
    "# Directory read by later cells for question_answer.xlsx and the split_docs.pkl cache.\n",
    "preprocess_output_dir = os.path.join(os.path.pardir, 'outputs', 'v1_20240713')\n",
    "# Working directory of this experiment (the Chroma store is persisted below it).\n",
    "expr_dir = os.path.join(os.path.pardir, 'experiments', expr_version)\n",
    "\n",
    "os.makedirs(expr_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf7e81e3-4c82-4842-aef5-7592caaf1d39",
   "metadata": {
    "papermill": {
     "duration": 0.100379,
     "end_time": "2024-11-23T14:29:01.862379",
     "exception": false,
     "start_time": "2024-11-23T14:29:01.762000",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "# 读取文档"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e6920e29-bc7d-4635-be06-d151eaf0e100",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:33:10.191849Z",
     "iopub.status.busy": "2024-11-24T08:33:10.191113Z",
     "iopub.status.idle": "2024-11-24T08:33:11.626502Z",
     "shell.execute_reply": "2024-11-24T08:33:11.626037Z",
     "shell.execute_reply.started": "2024-11-24T08:33:10.191779Z"
    },
    "papermill": {
     "duration": 2.012298,
     "end_time": "2024-11-23T14:29:03.974974",
     "exception": false,
     "start_time": "2024-11-23T14:29:01.962676",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain_community.document_loaders import PyPDFLoader\n",
    "\n",
    "# Load the source report PDF into LangChain documents.\n",
    "loader = PyPDFLoader(os.path.join(os.path.pardir, 'data', '2024全球经济金融展望报告.pdf'))\n",
    "documents = loader.load()\n",
    "\n",
    "# Question/answer sheet stored in the pre-processing output directory; it is filtered\n",
    "# into the test set further down in this notebook.\n",
    "qa_df = pd.read_excel(os.path.join(preprocess_output_dir, 'question_answer.xlsx'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "841ec659-4ad7-4e1f-b1ea-3477bf97fde3",
   "metadata": {
    "papermill": {
     "duration": 0.100297,
     "end_time": "2024-11-23T14:29:04.219302",
     "exception": false,
     "start_time": "2024-11-23T14:29:04.119005",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "# 文档切分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "74fe856a-7c19-4c3c-bb30-7abfa6298f74",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:33:11.627222Z",
     "iopub.status.busy": "2024-11-24T08:33:11.627056Z",
     "iopub.status.idle": "2024-11-24T08:33:11.633600Z",
     "shell.execute_reply": "2024-11-24T08:33:11.633183Z",
     "shell.execute_reply.started": "2024-11-24T08:33:11.627209Z"
    },
    "papermill": {
     "duration": 0.109229,
     "end_time": "2024-11-23T14:29:04.429069",
     "exception": false,
     "start_time": "2024-11-23T14:29:04.319840",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from uuid import uuid4\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "\n",
    "def split_docs(documents, filepath, chunk_size=400, chunk_overlap=40, seperators=['\\n\\n\\n', '\\n\\n'], force_split=False):\n",
    "    \"\"\"Split documents into chunks tagged with a UUID, caching the result on disk.\n",
    "\n",
    "    Args:\n",
    "        documents: Documents handed to ``RecursiveCharacterTextSplitter.split_documents``.\n",
    "        filepath: Pickle file used as the cache for the resulting chunks.\n",
    "        chunk_size: Forwarded to the splitter as ``chunk_size``.\n",
    "        chunk_overlap: Forwarded to the splitter as ``chunk_overlap``.\n",
    "        seperators: Separator strings forwarded to the splitter as ``separators``.\n",
    "        force_split: When True, ignore an existing cache file and split again.\n",
    "\n",
    "    Returns:\n",
    "        The chunk list, either unpickled from ``filepath`` or freshly split.\n",
    "\n",
    "    Note:\n",
    "        A cache hit returns the stored chunks unchanged; ``chunk_size``,\n",
    "        ``chunk_overlap`` and ``seperators`` are not compared with the values used\n",
    "        to build the cache. Pass ``force_split=True`` after changing them.\n",
    "    \"\"\"\n",
    "    if os.path.exists(filepath) and not force_split:\n",
    "        print('found cache, restoring...')\n",
    "        # pickle can execute arbitrary code while loading: only restore cache files you wrote yourself.\n",
    "        with open(filepath, 'rb') as cache_file:\n",
    "            return pickle.load(cache_file)\n",
    "\n",
    "    splitter = RecursiveCharacterTextSplitter(\n",
    "        chunk_size=chunk_size,\n",
    "        chunk_overlap=chunk_overlap,\n",
    "        separators=seperators\n",
    "    )\n",
    "    chunks = splitter.split_documents(documents)\n",
    "    for chunk in chunks:\n",
    "        # Random id stored in the metadata; the retrieval evaluation later in this\n",
    "        # notebook matches retrieved chunks against the expected one by this value.\n",
    "        chunk.metadata['uuid'] = str(uuid4())\n",
    "\n",
    "    # Context managers close the cache file even if pickling fails.\n",
    "    with open(filepath, 'wb') as cache_file:\n",
    "        pickle.dump(chunks, cache_file)\n",
    "\n",
    "    return chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "aa25540d-0504-4ae7-9804-9e3862b132d5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:33:11.634306Z",
     "iopub.status.busy": "2024-11-24T08:33:11.634112Z",
     "iopub.status.idle": "2024-11-24T08:33:11.645063Z",
     "shell.execute_reply": "2024-11-24T08:33:11.644645Z",
     "shell.execute_reply.started": "2024-11-24T08:33:11.634279Z"
    },
    "papermill": {
     "duration": 0.145583,
     "end_time": "2024-11-23T14:29:04.677429",
     "exception": false,
     "start_time": "2024-11-23T14:29:04.531846",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "found cache, restoring...\n"
     ]
    }
   ],
   "source": [
    "# NOTE: when split_docs.pkl already exists, split_docs returns the cached chunks as-is, so\n",
    "# chunk_size/chunk_overlap below only take effect on a cache miss (or with force_split=True).\n",
    "splitted_docs = split_docs(documents, os.path.join(preprocess_output_dir, 'split_docs.pkl'), chunk_size=500, chunk_overlap=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "21ae2e70-72f3-44e4-b1af-85348e8e70c5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:33:11.645682Z",
     "iopub.status.busy": "2024-11-24T08:33:11.645565Z",
     "iopub.status.idle": "2024-11-24T08:33:11.651613Z",
     "shell.execute_reply": "2024-11-24T08:33:11.651192Z",
     "shell.execute_reply.started": "2024-11-24T08:33:11.645670Z"
    },
    "papermill": {
     "duration": 0.255459,
     "end_time": "2024-11-23T14:29:05.065932",
     "exception": false,
     "start_time": "2024-11-23T14:29:04.810473",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import re\n",
    "from langchain.schema import Document\n",
    "\n",
    "# Strip the running page header from every page, then join the pages into one string.\n",
    "# The pattern is anchored with ^ and re.sub is called without flags, so only a header at\n",
    "# the very start of a page's text is removed.\n",
    "pattern = r\"^全球经济金融展望报告\\n中国银行研究院 \\d+ 2024年\"\n",
    "doc_content = '\\n'.join(re.sub(pattern, '', doc.page_content) for doc in documents)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "220dbc3a-fceb-4e49-a3f1-01e16660b2a6",
   "metadata": {
    "papermill": {
     "duration": 0.100209,
     "end_time": "2024-11-23T14:29:05.255871",
     "exception": false,
     "start_time": "2024-11-23T14:29:05.155662",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "# 检索"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8598a11c-25d8-4af1-a98b-06a8c394e261",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:33:11.652496Z",
     "iopub.status.busy": "2024-11-24T08:33:11.652154Z",
     "iopub.status.idle": "2024-11-24T08:33:12.481875Z",
     "shell.execute_reply": "2024-11-24T08:33:12.481420Z",
     "shell.execute_reply.started": "2024-11-24T08:33:11.652484Z"
    },
    "papermill": {
     "duration": 0.989203,
     "end_time": "2024-11-23T14:29:06.345534",
     "exception": false,
     "start_time": "2024-11-23T14:29:05.356331",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device: cuda\n"
     ]
    }
   ],
   "source": [
    "from langchain.embeddings import HuggingFaceBgeEmbeddings\n",
    "from langchain_community.vectorstores import Chroma\n",
    "import torch\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "# device = 'cpu'\n",
    "print(f'device: {device}')\n",
    "\n",
    "def get_embeddings(model_path):\n",
    "    \"\"\"Build a HuggingFaceBgeEmbeddings instance for `model_path`.\n",
    "\n",
    "    The model is placed on the notebook-level `device`, normalized embeddings are\n",
    "    requested through `encode_kwargs`, and the retrieval instruction string is\n",
    "    passed as `query_instruction`.\n",
    "    \"\"\"\n",
    "    bge_options = dict(\n",
    "        model_name=model_path,\n",
    "        model_kwargs=dict(device=device),\n",
    "        encode_kwargs=dict(normalize_embeddings=True),\n",
    "        query_instruction='为这个句子生成表示以用于检索相关文章：',\n",
    "    )\n",
    "    return HuggingFaceBgeEmbeddings(**bge_options)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e45931bb-6812-4680-9e73-2c2e63a7bbcb",
   "metadata": {
    "papermill": {
     "duration": 0.111733,
     "end_time": "2024-11-23T14:29:06.613089",
     "exception": false,
     "start_time": "2024-11-23T14:29:06.501356",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## 构建Contextual Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "52174650-948f-4b11-9292-b729e65a97ea",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:33:12.482722Z",
     "iopub.status.busy": "2024-11-24T08:33:12.482392Z",
     "iopub.status.idle": "2024-11-24T08:33:12.503790Z",
     "shell.execute_reply": "2024-11-24T08:33:12.503288Z",
     "shell.execute_reply.started": "2024-11-24T08:33:12.482709Z"
    },
    "papermill": {
     "duration": 0.402828,
     "end_time": "2024-11-23T14:29:07.116501",
     "exception": false,
     "start_time": "2024-11-23T14:29:06.713673",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.chat_models import ChatOpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d20b868a-2fdc-4f11-9a8f-14286a5857de",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:33:27.480746Z",
     "iopub.status.busy": "2024-11-24T08:33:27.480491Z",
     "iopub.status.idle": "2024-11-24T08:33:27.499417Z",
     "shell.execute_reply": "2024-11-24T08:33:27.498952Z",
     "shell.execute_reply.started": "2024-11-24T08:33:27.480725Z"
    },
    "papermill": {
     "duration": 0.114426,
     "end_time": "2024-11-23T14:29:07.361813",
     "exception": false,
     "start_time": "2024-11-23T14:29:07.247387",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Chat model used to write the per-chunk context. Endpoint and key come from the\n",
    "# DEEPSEEK_* environment variables (os.getenv returns None when a variable is unset).\n",
    "llm = ChatOpenAI(\n",
    "    model='deepseek-chat',\n",
    "    base_url=os.getenv('DEEPSEEK_BASE_URL'),\n",
    "    api_key=os.getenv('DEEPSEEK_API_KEY'),\n",
    "    temperature=0.001  # near-zero; presumably chosen to keep the generated context stable across runs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "26eaae03-f566-4358-b0ab-d2812993d743",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:33:28.128345Z",
     "iopub.status.busy": "2024-11-24T08:33:28.128173Z",
     "iopub.status.idle": "2024-11-24T08:33:28.131218Z",
     "shell.execute_reply": "2024-11-24T08:33:28.130874Z",
     "shell.execute_reply.started": "2024-11-24T08:33:28.128331Z"
    },
    "papermill": {
     "duration": 0.163967,
     "end_time": "2024-11-23T14:29:07.635025",
     "exception": false,
     "start_time": "2024-11-23T14:29:07.471058",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "DOCUMENT_CONTEXT_PROMPT = \"\"\"\n",
    "<document>\n",
    "{doc_content}\n",
    "</document>\n",
    "\"\"\"\n",
    "\n",
    "CHUNK_CONTEXT_PROMPT = \"\"\"\n",
    "这是我们想要放置在整个文档中的片段。\n",
    "<chunk>\n",
    "{chunk_content}\n",
    "</chunk>\n",
    "\n",
    "请为了提高检索效率，给出一个简短而精确的上下文来定位这个片段在整个文件中的位置。\n",
    "\"\"\"\n",
    "\n",
    "def situate_context(doc: str, chunk: str) -> str:\n",
    "    \"\"\"Ask the LLM where `chunk` sits inside `doc` and return the reply text.\n",
    "\n",
    "    Two consecutive user messages are sent: the whole document formatted into\n",
    "    DOCUMENT_CONTEXT_PROMPT, then the chunk formatted into CHUNK_CONTEXT_PROMPT.\n",
    "    \"\"\"\n",
    "    document_prompt = DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc)\n",
    "    chunk_prompt = CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk)\n",
    "    reply = llm.invoke([(\"user\", document_prompt), (\"user\", chunk_prompt)])\n",
    "    return reply.content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "b614a22a-ee62-412d-8320-69b56708f765",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:33:30.600246Z",
     "iopub.status.busy": "2024-11-24T08:33:30.600040Z",
     "iopub.status.idle": "2024-11-24T08:33:34.216209Z",
     "shell.execute_reply": "2024-11-24T08:33:34.215331Z",
     "shell.execute_reply.started": "2024-11-24T08:33:30.600232Z"
    },
    "papermill": {
     "duration": 7.095007,
     "end_time": "2024-11-23T14:29:14.833418",
     "exception": false,
     "start_time": "2024-11-23T14:29:07.738411",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'这个片段位于整个文件的开头部分，是报告的标题、摘要和研究团队介绍部分。它包含了报告的发布日期、主要内容概述以及研究团队的成员名单和联系方式。'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "situate_context(doc_content, chunk=splitted_docs[0].page_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "473c2187-ae9a-4223-b1f6-e54008b40149",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:33:38.143779Z",
     "iopub.status.busy": "2024-11-24T08:33:38.143467Z",
     "iopub.status.idle": "2024-11-24T08:37:59.532464Z",
     "shell.execute_reply": "2024-11-24T08:37:59.531405Z",
     "shell.execute_reply.started": "2024-11-24T08:33:38.143752Z"
    },
    "papermill": {
     "duration": 266.750722,
     "end_time": "2024-11-23T14:33:41.712280",
     "exception": false,
     "start_time": "2024-11-23T14:29:14.961558",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "204f3d0c3af74daab3bcae1d05287e4b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/52 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "# Append an LLM-written description of each chunk's place in the report to the chunk text.\n",
    "# One LLM call is made per chunk and each call carries the whole document, so this cell is slow.\n",
    "contextualized_chunks = []\n",
    "for chunk in tqdm(splitted_docs):\n",
    "    # situate_context formats its argument straight into the prompt, so pass the chunk text;\n",
    "    # passing the Document object would put its string form, not just its text, into the prompt.\n",
    "    contextualized_text = situate_context(doc_content, chunk.page_content)\n",
    "    contextualized_chunks.append(\n",
    "        Document(\n",
    "            page_content=chunk.page_content + '\\n\\n' + contextualized_text,\n",
    "            metadata=chunk.metadata\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "d5ae94fd-fb02-4f67-b489-298fb7b2f48f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:41:26.484295Z",
     "iopub.status.busy": "2024-11-24T08:41:26.483523Z",
     "iopub.status.idle": "2024-11-24T08:41:26.497446Z",
     "shell.execute_reply": "2024-11-24T08:41:26.495326Z",
     "shell.execute_reply.started": "2024-11-24T08:41:26.484225Z"
    },
    "papermill": {
     "duration": 0.121981,
     "end_time": "2024-11-23T14:33:41.973812",
     "exception": false,
     "start_time": "2024-11-23T14:33:41.851831",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Document(metadata={'source': 'data/2024全球经济金融展望报告.pdf', 'page': 0, 'uuid': 'e73a0c9d-d42b-4350-a4c3-b38bf67c68a5'}, page_content='研究院\\n全球经济金融展望报告\\n要点2024年年报（总第57期） 报告日期：2023年12月12日\\n●2023年全球经济增长动力持续回落，各国复苏分化，\\n发达经济体增速明显放缓，新兴经济体整体表现稳定。\\n全球贸易增长乏力，各国生产景气度逐渐回落，内需\\n对经济的拉动作用减弱。欧美央行货币政策紧缩态势\\n放缓，美元指数高位震荡后走弱，全球股市表现总体\\n好于预期，但区域分化明显。高利率环境抑制债券融\\n资需求，债券违约风险持续上升。\\n●展望2024年，预计全球经济复苏将依旧疲软，主要\\n经济体增长态势和货币政策走势将进一步分化。欧美\\n央行大概率结束本轮紧缩货币周期，美元指数将逐步\\n走弱，流向新兴经济体的跨境资本将增加。国际原油\\n市场短缺格局或延续，新能源发展成为重点。\\n●海湾六国经济发展与投资前景、高利率和高债务对\\n美国房地产市场脆弱性的影响等热点问题值得关注。中国银行研究院\\n全球经济金融研究课题组\\n组长：陈卫东\\n副组长：钟红\\n廖淑萍\\n成员：边卫红\\n熊启跃\\n王有鑫\\n曹鸿宇\\n李颖婷\\n王宁远\\n初晓\\n章凯莉\\n黄小军（纽约）\\n陆晓明（纽约）\\n黄承煜（纽约）\\n宋达志（伦敦）\\n李振龙（伦敦）\\n张传捷（伦敦）\\n刘冰彦（法兰克福）\\n温颍坤（法兰克福）\\n张明捷（法兰克福）\\n王哲（东京）\\n李彧（香港）\\n黎永康（香港）\\n联系人：王有鑫\\n电话：010-66594127\\n邮件：wangyouxin_hq@bank-of-china.com主要经济体GDP增速变化趋势（%）\\n资料来源：IMF，中国银行研究院\\n\\n这个片段位于文件的开头部分，是报告的标题、要点总结、报告日期、主要内容概述以及研究课题组成员名单。')"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "contextualized_chunks[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "f6f46c73-7369-448f-a89a-ed3d817cad47",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:41:27.274575Z",
     "iopub.status.busy": "2024-11-24T08:41:27.274362Z",
     "iopub.status.idle": "2024-11-24T08:41:43.601206Z",
     "shell.execute_reply": "2024-11-24T08:41:43.598749Z",
     "shell.execute_reply.started": "2024-11-24T08:41:27.274561Z"
    },
    "papermill": {
     "duration": 83.983138,
     "end_time": "2024-11-23T14:35:06.117207",
     "exception": false,
     "start_time": "2024-11-23T14:33:42.134069",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "from langchain_community.vectorstores import Chroma\n",
    "\n",
    "model_path = 'BAAI/bge-large-zh-v1.5'\n",
    "embeddings = get_embeddings(model_path)\n",
    "\n",
    "persist_directory = os.path.join(expr_dir, 'chroma', 'bge')\n",
    "# Delete any store left by a previous run so the index is rebuilt from scratch\n",
    "# (ignore_errors=True also covers the case where the directory does not exist yet).\n",
    "shutil.rmtree(persist_directory, ignore_errors=True)\n",
    "# Embed the contextualized chunks and persist the Chroma collection under persist_directory.\n",
    "vector_db = Chroma.from_documents(\n",
    "    contextualized_chunks,\n",
    "    embedding=embeddings,\n",
    "    persist_directory=persist_directory\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "566c6f3c-5777-4aa9-bc60-a3ee23050506",
   "metadata": {
    "papermill": {
     "duration": 0.10028,
     "end_time": "2024-11-23T14:35:06.350323",
     "exception": false,
     "start_time": "2024-11-23T14:35:06.250043",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## 计算检索准确率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "b03e3382-39e9-4932-a265-69b811041629",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:41:59.352621Z",
     "iopub.status.busy": "2024-11-24T08:41:59.352196Z",
     "iopub.status.idle": "2024-11-24T08:41:59.357138Z",
     "shell.execute_reply": "2024-11-24T08:41:59.356558Z",
     "shell.execute_reply.started": "2024-11-24T08:41:59.352598Z"
    },
    "papermill": {
     "duration": 0.110628,
     "end_time": "2024-11-23T14:35:06.583241",
     "exception": false,
     "start_time": "2024-11-23T14:35:06.472613",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_df = qa_df[(qa_df['dataset'] == 'test') & (qa_df['qa_type'] == 'detailed')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "32c3ad14-b217-44aa-bdb9-909b9d559668",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:41:59.893981Z",
     "iopub.status.busy": "2024-11-24T08:41:59.893615Z",
     "iopub.status.idle": "2024-11-24T08:41:59.901916Z",
     "shell.execute_reply": "2024-11-24T08:41:59.900493Z",
     "shell.execute_reply.started": "2024-11-24T08:41:59.893949Z"
    },
    "papermill": {
     "duration": 0.138041,
     "end_time": "2024-11-23T14:35:06.823783",
     "exception": false,
     "start_time": "2024-11-23T14:35:06.685742",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_hit_stat_df(top_k_arr=list(range(1, 9))):\n",
    "    \"\"\"Record, for every k in `top_k_arr`, whether retrieval finds each test question's chunk.\n",
    "\n",
    "    For each k and each row of the notebook-level `test_df`, runs\n",
    "    `vector_db.similarity_search(question, k=k)` and checks whether the row's\n",
    "    `uuid` is among the `uuid` metadata values of the retrieved chunks.\n",
    "\n",
    "    Returns:\n",
    "        A DataFrame with one row per (k, question) and the columns `question`,\n",
    "        `top_k`, `hit` (1 or 0) and `retrieved_chunks` (number of chunks returned).\n",
    "    \"\"\"\n",
    "    records = []\n",
    "\n",
    "    for top_k in tqdm(top_k_arr):\n",
    "        for _, qa_row in test_df.iterrows():\n",
    "            query = qa_row['question']\n",
    "            expected_uuid = qa_row['uuid']\n",
    "            retrieved = vector_db.similarity_search(query, k=top_k)\n",
    "            found_uuids = [item.metadata['uuid'] for item in retrieved]\n",
    "\n",
    "            records.append(dict(\n",
    "                question=query,\n",
    "                top_k=top_k,\n",
    "                hit=int(expected_uuid in found_uuids),\n",
    "                retrieved_chunks=len(retrieved),\n",
    "            ))\n",
    "\n",
    "    return pd.DataFrame(records)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "a63797c7-4151-4f55-8e5d-080c34265393",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:42:00.672108Z",
     "iopub.status.busy": "2024-11-24T08:42:00.671349Z",
     "iopub.status.idle": "2024-11-24T08:42:18.197413Z",
     "shell.execute_reply": "2024-11-24T08:42:18.196994Z",
     "shell.execute_reply.started": "2024-11-24T08:42:00.672039Z"
    },
    "papermill": {
     "duration": 18.284887,
     "end_time": "2024-11-23T14:35:25.215134",
     "exception": false,
     "start_time": "2024-11-23T14:35:06.930247",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aa4ae1e0a8ab4741ab0dbe11b0a490b6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "hit_stat_df = get_hit_stat_df()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "d4890789-a44c-41de-b17f-0ff505788494",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:42:18.198199Z",
     "iopub.status.busy": "2024-11-24T08:42:18.198065Z",
     "iopub.status.idle": "2024-11-24T08:42:18.204605Z",
     "shell.execute_reply": "2024-11-24T08:42:18.204291Z",
     "shell.execute_reply.started": "2024-11-24T08:42:18.198186Z"
    },
    "papermill": {
     "duration": 0.145649,
     "end_time": "2024-11-23T14:35:25.499527",
     "exception": false,
     "start_time": "2024-11-23T14:35:25.353878",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>top_k</th>\n",
       "      <th>hit_rate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0.397849</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>0.602151</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>0.720430</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>0.763441</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>0.806452</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>6</td>\n",
       "      <td>0.817204</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>7</td>\n",
       "      <td>0.827957</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>8</td>\n",
       "      <td>0.838710</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   top_k  hit_rate\n",
       "0      1  0.397849\n",
       "1      2  0.602151\n",
       "2      3  0.720430\n",
       "3      4  0.763441\n",
       "4      5  0.806452\n",
       "5      6  0.817204\n",
       "6      7  0.827957\n",
       "7      8  0.838710"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hit_stat_df.groupby(['top_k'])['hit'].mean().reset_index().rename(columns={'hit': 'hit_rate'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "b0b086d1-6cec-4743-8df6-2ab3b1593689",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:42:18.205158Z",
     "iopub.status.busy": "2024-11-24T08:42:18.205031Z",
     "iopub.status.idle": "2024-11-24T08:42:18.570682Z",
     "shell.execute_reply": "2024-11-24T08:42:18.570162Z",
     "shell.execute_reply.started": "2024-11-24T08:42:18.205145Z"
    },
    "papermill": {
     "duration": 0.525548,
     "end_time": "2024-11-23T14:35:26.224632",
     "exception": false,
     "start_time": "2024-11-23T14:35:25.699084",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Axes: xlabel='top_k', ylabel='hit'>"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAGxCAYAAACeKZf2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAor0lEQVR4nO3df3RU9Z3/8ddkIAkICWJIAjGSAgoEIaGJSQO1sHVsFj2s7G7daKlJR0xPK2PR+eoXIpqIPxhsbQxHKREk4ldLie3ijy4Yf0wNlhINJtJCFZQqBIEEcqwEok50Zr5/eDp2loCAydzJJ8/HOfcc5uZzM+9bj/XJnTsztmAwGBQAAIAhYqweAAAAoCcRNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMMsDqASItEAjo4MGDGjp0qGw2m9XjAACA0xAMBnXs2DGNGjVKMTGnvjbT7+Lm4MGDSk9Pt3oMAABwFvbv36/zzz//lGv6XdwMHTpU0hf/4yQkJFg8DQAAOB0dHR1KT08P/Xf8VPpd3PzjpaiEhATiBgCAPuZ0binhhmIAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYZYPUAAACgZ9x1111Wj9Ajvu55cOUGAAAYhbgBAABGIW4AAIBRuOcGAGCct+/7g9Uj9IiJi79r9Qh9ElduAACAUYgbAABgFOIGAAAYhbgBAABG4YZiADDYfT/8vtUj9IjFT/7O6hHQh3DlBgAAGMXyuFmxYoUyMjIUHx+v/Px8NTY2nnJ9VVWVxo8fr0GDBik9PV233HKLPv300whNCwAAop2lcVNbWyu3262Kigo1NzcrKytLhYWFOnz4cLfr161bp0WLFqmiokJvv/221qxZo9raWt1+++0RnhwAAEQrS++5qaysVGlpqZxOpySpurpaGzduVE1NjRYtWnTC+q1bt2r69On6wQ9+IEnKyMjQtddeq9dffz2icwPoex7+P7+3eoQe4frlbKtHAKKeZVduurq61NTUJIfD8eUwMTFyOBxqaGjo9php06apqakp9NLVe++9p02bNumKK66IyMwAACD6WXblpr29XX6/XykpKWH7U1JStGvXrm6P+cEPfqD29nZ9+9vfVjAY1Oeff66f/OQnp3xZyufzyefzhR53dHT0zAkAAICoZPkNxWeivr5eS5cu1a9+9Ss1Nzdrw4YN2rhxo+65556THuPxeJSYmBja0tPTIzgxAACINMuu3CQlJclut6utrS1sf1tbm1JTU7s95s4779R1112nG264QZI0efJkdXZ26sc//rEWL16smJgTW62srExutzv0uKOjg8ABAMBgll25iY2NVU5Ojrxeb2hfIBCQ1+tVQUFBt8d8/PHHJwSM3W6XJAWDwW6PiYuLU0JCQtgGAADMZem7pdxut0pKSpSbm6u8vDxVVVWps7Mz9O6p4uJipaWlyePxSJJmz56tyspKTZ06Vfn5+dqzZ4/uvPNOzZ49OxQ5AACgf7M0boqKinTkyBGVl5ertbVV2dnZqqurC91k3NLSEnal5o477pDNZtMdd9yhAwcOaMSIEZo9e7buu+8+q04BAABEGcu/W8rlcsnlcnX7s/r6+rDHAwYMUEVFhSoqKiIwGQAA6Iv61LulAAAAvgpxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADCK5Z9QDCCyNn9nhtUj9IgZr262egQAUYorNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKAOsHgCwyvSHpls9Qo/4001/
snoEAIgqXLkBAABGIW4AAIBRiBsAAGCUqIibFStWKCMjQ/Hx8crPz1djY+NJ186cOVM2m+2E7corr4zgxAAAIFpZHje1tbVyu92qqKhQc3OzsrKyVFhYqMOHD3e7fsOGDTp06FBo27lzp+x2u66++uoITw4AAKKR5XFTWVmp0tJSOZ1OZWZmqrq6WoMHD1ZNTU2364cPH67U1NTQ9tJLL2nw4MHEDQAAkGRx3HR1dampqUkOhyO0LyYmRg6HQw0NDaf1O9asWaNrrrlG55xzTm+NCQAA+hBLP+emvb1dfr9fKSkpYftTUlK0a9eurzy+sbFRO3fu1Jo1a066xufzyefzhR53dHSc/cAAACDqWf6y1NexZs0aTZ48WXl5eSdd4/F4lJiYGNrS09MjOCEAAIg0S+MmKSlJdrtdbW1tYfvb2tqUmpp6ymM7Ozu1fv16zZs375TrysrKdPTo0dC2f//+rz03AACIXpbGTWxsrHJycuT1ekP7AoGAvF6vCgoKTnnsb3/7W/l8Pv3whz885bq4uDglJCSEbQAAwFyWf7eU2+1WSUmJcnNzlZeXp6qqKnV2dsrpdEqSiouLlZaWJo/HE3bcmjVrNGfOHJ133nlWjA0AAKKU5XFTVFSkI0eOqLy8XK2trcrOzlZdXV3oJuOWlhbFxIRfYNq9e7e2bNmiF1980YqRAQBAFLM8biTJ5XLJ5XJ1+7P6+voT9o0fP17BYLCXpwIAAH1Rn363FAAAwP9G3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKJbHzYoVK5SRkaH4+Hjl5+ersbHxlOs/+ugjzZ8/XyNHjlRcXJwuuugibdq0KULTAgCAaDfAyievra2V2+1WdXW18vPzVVVVpcLCQu3evVvJycknrO/q6tLll1+u5ORk/e53v1NaWpr27dunYcOGRX54AAAQlSyNm8rKSpWWlsrpdEqSqqurtXHjRtXU1GjRokUnrK+pqdGHH36orVu3auDAgZKkjIyMSI4MAACinGUvS3V1dampqUkOh+PLYWJi5HA41NDQ0O0xzz33nAoKCjR//nylpKTo4osv1tKlS+X3+yM1NgAAiHKWXblpb2+X3+9XSkpK2P6UlBTt2rWr22Pee+89/eEPf9DcuXO1adMm7dmzRzfeeKM+++wzVVRUdHuMz+eTz+cLPe7o6Oi5kzBEy92TrR6hR1xQvsPqEQAAUcDyG4rPRCAQUHJyslatWqWcnBwVFRVp8eLFqq6uPukxHo9HiYmJoS09PT2CEwMAgEizLG6SkpJkt9vV1tYWtr+trU2pqandHjNy5EhddNFFstvtoX0TJ05Ua2ururq6uj2mrKxMR48eDW379+/vuZMAAABRx7K4iY2NVU5Ojrxeb2hfIBCQ1+tVQUFBt8dMnz5de/bsUSAQCO175513NHLkSMXGxnZ7TFxcnBISEsI2AABgLktflnK73Vq9erUef/xxvf322/rpT3+qzs7O0LuniouLVVZWFlr/05/+VB9++KEWLFigd955Rxs3btTSpUs1f/58q04BAABEGUvfCl5UVKQjR46ovLxcra2tys7OVl1dXegm45aWFsXEfNlf6enpeuGFF3TLLbdoypQpSktL04IFC7Rw4UKrTgEAAEQZS+NGklwul1wuV7c/q6+vP2FfQUGBXnvttV6eCgAA9FV96t1SAAAAX4W4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgb
AABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRoiJuVqxYoYyMDMXHxys/P1+NjY0nXbt27VrZbLawLT4+PoLTAgCAaGZ53NTW1srtdquiokLNzc3KyspSYWGhDh8+fNJjEhISdOjQodC2b9++CE4MAACimeVxU1lZqdLSUjmdTmVmZqq6ulqDBw9WTU3NSY+x2WxKTU0NbSkpKRGcGAAARDNL46arq0tNTU1yOByhfTExMXI4HGpoaDjpccePH9fo0aOVnp6uq666Sn/9618jMS4AAOgDLI2b9vZ2+f3+E668pKSkqLW1tdtjxo8fr5qaGj377LN68sknFQgENG3aNH3wwQfdrvf5fOro6AjbAACAuSx/WepMFRQUqLi4WNnZ2ZoxY4Y2bNigESNG6JFHHul2vcfjUWJiYmhLT0+P8MQAACCSLI2bpKQk2e12tbW1he1va2tTamrqaf2OgQMHaurUqdqzZ0+3Py8rK9PRo0dD2/79+7/23AAAIHpZGjexsbHKycmR1+sN7QsEAvJ6vSooKDit3+H3+7Vjxw6NHDmy25/HxcUpISEhbAMAAOYaYPUAbrdbJSUlys3NVV5enqqqqtTZ2Smn0ylJKi4uVlpamjwejyTp7rvv1re+9S2NGzdOH330kX7xi19o3759uuGGG6w8DQAAECUsj5uioiIdOXJE5eXlam1tVXZ2turq6kI3Gbe0tCgm5ssLTH//+99VWlqq1tZWnXvuucrJydHWrVuVmZlp1SkAAIAoYnncSJLL5ZLL5er2Z/X19WGPH3zwQT344IMRmAoAAPRFfe7dUgAAAKdC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIxyVnHz3e9+Vx999NEJ+zs6OvTd7373684EAABw1s4qburr69XV1XXC/k8//VR//OMfv/ZQAAAAZ2vAmSz+y1/+EvrzW2+9pdbW1tBjv9+vuro6paWl9dx0AAAAZ+iM4iY7O1s2m002m63bl58GDRqkhx56qMeGAwAAOFNnFDfvv/++gsGgxowZo8bGRo0YMSL0s9jYWCUnJ8tut/f4kAAAAKfrjOJm9OjRkqRAINArwwAAAHxdpx03zz33nGbNmqWBAwfqueeeO+Xaf/u3f/vagwEAAJyN046bOXPmqLW1VcnJyZozZ85J19lsNvn9/p6YDQAA4Iyddtz880tRvCwFAACi1Rndc/PPvF6vvF6vDh8+HBY7NptNa9as6ZHhAAAAztRZxc2SJUt09913Kzc3VyNHjpTNZuvpuQAAAM7KWcVNdXW11q5dq+uuu66n5wEAAPhazurrF7q6ujRt2rSengUAAOBrO6u4ueGGG7Ru3bqengUAAOBrO+2Xpdxud+jPgUBAq1at0ssvv6wpU6Zo4MCBYWsrKyt7bkIAAIAzcNpx8+abb4Y9zs7OliTt3LkzbD83FwMAACuddty88sorvTkHAABAjzire24AAACiVVTEzYoVK5SRkaH4+Hjl5+ersbHxtI5bv369bDbbKb8OAgAA9C+Wx01tba3cbrcqKirU3NysrKwsFRYW6vDhw6c8bu/evbr11lt16aWXRmhSAADQF1geN5WVlSotLZXT6VRmZqaqq6s1ePBg1dTUnPQYv9+vuXPnasmSJRozZkwEpwUAANHO0rjp6upSU1OTHA5HaF9MTIwcDocaGhpOetzdd9+t5ORkzZs3LxJjAgCAPuSsvzizJ7S3t8vv9yslJSVsf0pKinbt2tXtMVu2bNGaNWu0ffv203oOn88nn88XetzR0XHW8wIAgOhn+ctSZ+LYsWO67rrrtHr1aiUlJZ3WMR6PR4mJiaEtPT29l6cEAABWsvTKTVJSkux2u9ra2sL2
t7W1KTU19YT1f/vb37R3717Nnj07tC8QCEiSBgwYoN27d2vs2LFhx5SVlYV9unJHR8dJAyfntv931ucSTZp+UWz1CAAAWMbSuImNjVVOTo68Xm/o7dyBQEBer1cul+uE9RMmTNCOHTvC9t1xxx06duyYli9f3m20xMXFKS4urlfmBwAA0cfSuJG++M6qkpIS5ebmKi8vT1VVVers7JTT6ZQkFRcXKy0tTR6PR/Hx8br44ovDjh82bJgknbAfAAD0T5bHTVFRkY4cOaLy8nK1trYqOztbdXV1oZuMW1paFBPTp24NAgAAFrI8biTJ5XJ1+zKUJNXX15/y2LVr1/b8QAAAoM/ikggAADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAo0RF3KxYsUIZGRmKj49Xfn6+GhsbT7p2w4YNys3N1bBhw3TOOecoOztbTzzxRASnBQAA0czyuKmtrZXb7VZFRYWam5uVlZWlwsJCHT58uNv1w4cP1+LFi9XQ0KC//OUvcjqdcjqdeuGFFyI8OQAAiEaWx01lZaVKS0vldDqVmZmp6upqDR48WDU1Nd2unzlzpv793/9dEydO1NixY7VgwQJNmTJFW7ZsifDkAAAgGlkaN11dXWpqapLD4Qjti4mJkcPhUENDw1ceHwwG5fV6tXv3bn3nO9/pzVEBAEAfMcDKJ29vb5ff71dKSkrY/pSUFO3ateukxx09elRpaWny+Xyy2+361a9+pcsvv7zbtT6fTz6fL/S4o6OjZ4YHAABRydK4OVtDhw7V9u3bdfz4cXm9Xrndbo0ZM0YzZ848Ya3H49GSJUsiPyQAALCEpXGTlJQku92utra2sP1tbW1KTU096XExMTEaN26cJCk7O1tvv/22PB5Pt3FTVlYmt9sdetzR0aH09PSeOQEAABB1LL3nJjY2Vjk5OfJ6vaF9gUBAXq9XBQUFp/17AoFA2EtP/ywuLk4JCQlhGwAAMJflL0u53W6VlJQoNzdXeXl5qqqqUmdnp5xOpySpuLhYaWlp8ng8kr54mSk3N1djx46Vz+fTpk2b9MQTT2jlypVWngYAAIgSlsdNUVGRjhw5ovLycrW2tio7O1t1dXWhm4xbWloUE/PlBabOzk7deOON+uCDDzRo0CBNmDBBTz75pIqKiqw6BQAAEEUsjxtJcrlccrlc3f6svr4+7PG9996re++9NwJTAQCAvsjyD/EDAADoScQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIwSFXGzYsUKZWRkKD4+Xvn5+WpsbDzp2tWrV+vSSy/Vueeeq3PPPVcOh+OU6wEAQP9iedzU1tbK7XaroqJCzc3NysrKUmFhoQ4fPtzt+vr6el177bV65ZVX1NDQoPT0dH3ve9/TgQMHIjw5AACIRpbHTWVlpUpLS+V0OpWZmanq6moNHjxYNTU13a7/9a9/rRtvvFHZ2dmaMGGCHn30UQUCAXm93ghPDgAAopGlcdPV1aWmpiY5HI7QvpiYGDkcDjU0NJzW7/j444/12Wefafjw4b01JgAA6EMGWPnk7e3t8vv9SklJCdufkpKiXbt2ndbvWLhwoUaNGhUWSP/M5/PJ5/OFHnd0dJz9
wAAAIOpZ/rLU17Fs2TKtX79eTz/9tOLj47td4/F4lJiYGNrS09MjPCUAAIgkS+MmKSlJdrtdbW1tYfvb2tqUmpp6ymMfeOABLVu2TC+++KKmTJly0nVlZWU6evRoaNu/f3+PzA4AAKKTpXETGxurnJycsJuB/3FzcEFBwUmP+/nPf6577rlHdXV1ys3NPeVzxMXFKSEhIWwDAADmsvSeG0lyu90qKSlRbm6u8vLyVFVVpc7OTjmdTklScXGx0tLS5PF4JEn333+/ysvLtW7dOmVkZKi1tVWSNGTIEA0ZMsSy8wAAANHB8rgpKirSkSNHVF5ertbWVmVnZ6uuri50k3FLS4tiYr68wLRy5Up1dXXp+9//ftjvqaio0F133RXJ0QEAQBSyPG4kyeVyyeVydfuz+vr6sMd79+7t/YEAAECf1affLQUAAPC/ETcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMIrlcbNixQplZGQoPj5e+fn5amxsPOnav/71r/rP//xPZWRkyGazqaqqKnKDAgCAPsHSuKmtrZXb7VZFRYWam5uVlZWlwsJCHT58uNv1H3/8scaMGaNly5YpNTU1wtMCAIC+wNK4qaysVGlpqZxOpzIzM1VdXa3Bgwerpqam2/WXXHKJfvGLX+iaa65RXFxchKcFAAB9gWVx09XVpaamJjkcji+HiYmRw+FQQ0ODVWMBAIA+boBVT9ze3i6/36+UlJSw/SkpKdq1a1ePPY/P55PP5ws97ujo6LHfDQAAoo/lNxT3No/Ho8TExNCWnp5u9UgAAKAXWRY3SUlJstvtamtrC9vf1tbWozcLl5WV6ejRo6Ft//79Pfa7AQBA9LEsbmJjY5WTkyOv1xvaFwgE5PV6VVBQ0GPPExcXp4SEhLANAACYy7J7biTJ7XarpKREubm5ysvLU1VVlTo7O+V0OiVJxcXFSktLk8fjkfTFTchvvfVW6M8HDhzQ9u3bNWTIEI0bN86y8wAAANHD0rgpKirSkSNHVF5ertbWVmVnZ6uuri50k3FLS4tiYr68uHTw4EFNnTo19PiBBx7QAw88oBkzZqi+vj7S4wMAgChkadxIksvlksvl6vZn/ztYMjIyFAwGIzAVAADoq4x/txQAAOhfiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGCUq4mbFihXKyMhQfHy88vPz1djYeMr1v/3tbzVhwgTFx8dr8uTJ2rRpU4QmBQAA0c7yuKmtrZXb7VZFRYWam5uVlZWlwsJCHT58uNv1W7du1bXXXqt58+bpzTff1Jw5czRnzhzt3LkzwpMDAIBoZHncVFZWqrS0VE6nU5mZmaqurtbgwYNVU1PT7frly5frX//1X3Xbbbdp4sSJuueee/TNb35TDz/8cIQnBwAA0cjSuOnq6lJTU5McDkdoX0xMjBwOhxoaGro9pqGhIWy9JBUWFp50PQAA6F8GWPnk7e3t8vv9SklJCdufkpKiXbt2dXtMa2trt+tbW1u7Xe/z+eTz+UKPjx49Kknq6Og4Ya3f98kZzR+tuju3Uzn2qb+XJomsMz3vzz/5vJcmiawzPe/Oz/vneX/i+7iXJomsMz3vTz/7rJcmiawz
Pe/jn3b20iSRdabn/c//vevLujvvf+wLBoNfebylcRMJHo9HS5YsOWF/enq6BdNERuJDP7F6BGt4Eq2ewBKJC/vneSuxf573/11h9QTWuPep/vnPW/daPYA1li1bdtKfHTt2TIlf8e+/pXGTlJQku92utra2sP1tbW1KTU3t9pjU1NQzWl9WVia32x16HAgE9OGHH+q8886TzWb7mmdwZjo6OpSenq79+/crISEhos9tJc6b8+4POG/Ouz+w8ryDwaCOHTumUaNGfeVaS+MmNjZWOTk58nq9mjNnjqQv4sPr9crlcnV7TEFBgbxer26++ebQvpdeekkFBQXdro+Li1NcXFzYvmHDhvXE+GctISGhX/3L8A+cd//CefcvnHf/YtV5f9UVm3+w/GUpt9utkpIS5ebmKi8vT1VVVers7JTT6ZQkFRcXKy0tTR6PR5K0YMECzZgxQ7/85S915ZVXav369XrjjTe0atUqK08DAABECcvjpqioSEeOHFF5eblaW1uVnZ2turq60E3DLS0tion58k1d06ZN07p163THHXfo9ttv14UXXqhnnnlGF198sVWnAAAAoojlcSNJLpfrpC9D1dfXn7Dv6quv1tVXX93LU/W8uLg4VVRUnPAymek4b867P+C8Oe/+oK+cty14Ou+pAgAA6CMs/4RiAACAnkTcAAAAoxA3AADAKMRNBLz66quaPXu2Ro0aJZvNpmeeecbqkSLC4/Hokksu0dChQ5WcnKw5c+Zo9+7dVo/V61auXKkpU6aEPgeioKBAzz//vNVjRdyyZctks9nCPpPKRHfddZdsNlvYNmHCBKvHiogDBw7ohz/8oc477zwNGjRIkydP1htvvGH1WL0qIyPjhH/eNptN8+fPt3q0XuX3+3XnnXfqG9/4hgYNGqSxY8fqnnvuOa2vQrBCVLxbynSdnZ3KysrS9ddfr//4j/+wepyI2bx5s+bPn69LLrlEn3/+uW6//XZ973vf01tvvaVzzjnH6vF6zfnnn69ly5bpwgsvVDAY1OOPP66rrrpKb775piZNmmT1eBGxbds2PfLII5oyZYrVo0TEpEmT9PLLL4ceDxhg/v+1/v3vf9f06dP1L//yL3r++ec1YsQIvfvuuzr33HOtHq1Xbdu2TX7/l9/Ht3PnTl1++eV98h28Z+L+++/XypUr9fjjj2vSpEl644035HQ6lZiYqJ/97GdWj3cC8/8NjAKzZs3SrFmzrB4j4urq6sIer127VsnJyWpqatJ3vvMdi6bqfbNnzw57fN9992nlypV67bXX+kXcHD9+XHPnztXq1at1773944txBgwYcNKvgDHV/fffr/T0dD322GOhfd/4xjcsnCgyRowYEfZ42bJlGjt2rGbMmGHRRJGxdetWXXXVVbryyislfXEF6ze/+Y0aGxstnqx7vCyFiPnHN7IPHz7c4kkix+/3a/369ers7DzpV4SYZv78+bryyivlcDisHiVi3n33XY0aNUpjxozR3Llz1dLSYvVIve65555Tbm6urr76aiUnJ2vq1KlavXq11WNFVFdXl5588kldf/31Ef+uwkibNm2avF6v3nnnHUnSn//8Z23ZsiVq/+LOlRtERCAQ0M0336zp06f3i0+T3rFjhwoKCvTpp59qyJAhevrpp5WZmWn1WL1u/fr1am5u1rZt26weJWLy8/O1du1ajR8/XocOHdKSJUt06aWXaufOnRo6dKjV4/Wa9957TytXrpTb7dbtt9+ubdu26Wc/+5liY2NVUlJi9XgR8cwzz+ijjz7Sj370I6tH6XWLFi1SR0eHJkyYILvdLr/fr/vuu09z5861erRuETeIiPnz52vnzp3asmWL1aNExPjx47V9+3YdPXpUv/vd71RSUqLNmzcbHTj79+/XggUL9NJLLyk+Pt7qcSLmn//mOmXKFOXn52v06NF66qmnNG/ePAsn612BQEC5ublaunSpJGnq1KnauXOnqqur+03crFmzRrNmzTqtb6nu
65566in9+te/1rp16zRp0iRt375dN998s0aNGhWV/7yJG/Q6l8ul//mf/9Grr76q888/3+pxIiI2Nlbjxo2TJOXk5Gjbtm1avny5HnnkEYsn6z1NTU06fPiwvvnNb4b2+f1+vfrqq3r44Yfl8/lkt9stnDAyhg0bposuukh79uyxepReNXLkyBNifeLEifrv//5viyaKrH379unll1/Whg0brB4lIm677TYtWrRI11xzjSRp8uTJ2rdvnzweD3GD/iUYDOqmm27S008/rfr6+n5xs+HJBAIB+Xw+q8foVZdddpl27NgRts/pdGrChAlauHBhvwgb6Ysbqv/2t7/puuuus3qUXjV9+vQTPtrhnXfe0ejRoy2aKLIee+wxJScnh26wNd3HH38c9iXWkmS32xUIBCya6NSImwg4fvx42N/i3n//fW3fvl3Dhw/XBRdcYOFkvWv+/Plat26dnn32WQ0dOlStra2SpMTERA0aNMji6XpPWVmZZs2apQsuuEDHjh3TunXrVF9frxdeeMHq0XrV0KFDT7if6pxzztF5551n9H1Wt956q2bPnq3Ro0fr4MGDqqiokN1u17XXXmv1aL3qlltu0bRp07R06VL913/9lxobG7Vq1SqtWrXK6tF6XSAQ0GOPPaaSkpJ+8bZ/6Yt3gd5333264IILNGnSJL355puqrKzU9ddfb/Vo3Qui173yyitBSSdsJSUlVo/Wq7o7Z0nBxx57zOrRetX1118fHD16dDA2NjY4YsSI4GWXXRZ88cUXrR7LEjNmzAguWLDA6jF6VVFRUXDkyJHB2NjYYFpaWrCoqCi4Z88eq8eKiN///vfBiy++OBgXFxecMGFCcNWqVVaPFBEvvPBCUFJw9+7dVo8SMR0dHcEFCxYEL7jggmB8fHxwzJgxwcWLFwd9Pp/Vo3WLbwUHAABG4XNuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgD0axkZGaqqqrJ6DAA9iLgBEDVmzpypm2++2eoxAPRxxA0AADAKcQMgKvzoRz/S5s2btXz5ctlsNtlsNu3du1ebN29WXl6e4uLiNHLkSC1atEiff/556LiZM2fK5XLJ5XIpMTFRSUlJuvPOO3W2X5v36KOPatiwYfJ6vT11agAijLgBEBWWL1+ugoIClZaW6tChQzp06JAGDhyoK664Qpdccon+/Oc/a+XKlVqzZo3uvffesGMff/xxDRgwQI2NjVq+fLkqKyv16KOPnvEMP//5z7Vo0SK9+OKLuuyyy3rq1ABE2ACrBwAASUpMTFRsbKwGDx6s1NRUSdLixYuVnp6uhx9+WDabTRMmTNDBgwe1cOFClZeXKybmi7+fpaen68EHH5TNZtP48eO1Y8cOPfjggyotLT3t51+4cKGeeOIJbd68WZMmTeqVcwQQGVy5ARC13n77bRUUFMhms4X2TZ8+XcePH9cHH3wQ2vetb30rbE1BQYHeffdd+f3+03qeX/7yl1q9erW2bNlC2AAGIG4A9HuXXnqp/H6/nnrqKatHAdADiBsAUSM2NjbsasvEiRPV0NAQdnPwn/70Jw0dOlTnn39+aN/rr78e9ntee+01XXjhhbLb7af1vHl5eXr++ee1dOlSPfDAA1/zLABYjbgBEDUyMjL0+uuva+/evWpvb9eNN96o/fv366abbtKuXbv07LPPqqKiQm63O3S/jSS1tLTI7XZr9+7d+s1vfqOHHnpICxYsOKPnnjZtmjZt2qQlS5bwoX5AH8cNxQCixq233qqSkhJlZmbqk08+0fvvv69NmzbptttuU1ZWloYPH6558+bpjjvuCDuuuLhYn3zyifLy8mS327VgwQL9+Mc/PuPn//a3v62NGzfqiiuukN1u10033dRTpwYggmzBs/0wCACIAjNnzlR2djZXWwCE8LIUAAAwCnEDwFh//OMfNWTIkJNuAMzEy1IAjPXJJ5/owIEDJ/35uHHjIjgN
gEghbgAAgFF4WQoAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABglP8PS43syVO3pSgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "sns.barplot(x='top_k', y='hit', data=hit_stat_df, errorbar=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7925564a-7d30-4914-baaf-4a00abb7686d",
   "metadata": {
    "papermill": {
     "duration": 0.109216,
     "end_time": "2024-11-23T14:35:26.464009",
     "exception": false,
     "start_time": "2024-11-23T14:35:26.354793",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "# 生成答案"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "50404beb-3be0-4aaa-b124-8c7a52b84531",
   "metadata": {
    "editable": true,
    "execution": {
     "iopub.execute_input": "2024-11-24T08:43:02.034225Z",
     "iopub.status.busy": "2024-11-24T08:43:02.033790Z",
     "iopub.status.idle": "2024-11-24T08:43:02.048987Z",
     "shell.execute_reply": "2024-11-24T08:43:02.048571Z",
     "shell.execute_reply.started": "2024-11-24T08:43:02.034204Z"
    },
    "papermill": {
     "duration": 0.159318,
     "end_time": "2024-11-23T14:35:26.768506",
     "exception": false,
     "start_time": "2024-11-23T14:35:26.609188",
     "status": "completed"
    },
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def rag(llm, query, n_chunks=4):\n",
    "    prompt_tmpl = \"\"\"\n",
    "你是一个金融分析师，擅长根据所获取的信息片段，对问题进行分析和推理。\n",
    "你的任务是根据所获取的信息片段（<<<<context>>><<<</context>>>之间的内容）回答问题。\n",
    "回答保持简洁，不必重复问题，不要添加描述性解释和与答案无关的任何内容。\n",
    "已知信息：\n",
    "<<<<context>>>\n",
    "{{knowledge}}\n",
    "<<<</context>>>\n",
    "\n",
    "问题：{{query}}\n",
    "请回答：\n",
    "\"\"\".strip()\n",
    "    chunks = vector_db.similarity_search(query, k=n_chunks)\n",
    "    prompt = prompt_tmpl.replace('{{knowledge}}', '\\n\\n'.join([doc.page_content for doc in chunks])).replace('{{query}}', query)\n",
    "    retry_count = 3\n",
    "\n",
    "    resp = ''\n",
    "    while retry_count > 0:\n",
    "        try:\n",
    "            resp = llm.invoke(prompt)\n",
    "            break\n",
    "        except Exception as e:\n",
    "            retry_count -= 1\n",
    "            sleeping_seconds = 2 ** (4 - retry_count)\n",
    "            print(f\"query={query}, error={e}, sleeping={sleeping_seconds}, remaining retry count={retry_count}\")\n",
    "            \n",
    "            time.sleep(sleeping_seconds)\n",
    "    \n",
    "    return resp, chunks"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95e5a804-2dc6-411c-ba71-6ccf765b2b73",
   "metadata": {
    "papermill": {
     "duration": 0.135973,
     "end_time": "2024-11-23T14:35:27.001401",
     "exception": false,
     "start_time": "2024-11-23T14:35:26.865428",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## 预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "27132c3b-0051-4df6-bf57-fd804acb8d17",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:43:03.163190Z",
     "iopub.status.busy": "2024-11-24T08:43:03.162368Z",
     "iopub.status.idle": "2024-11-24T08:43:03.279539Z",
     "shell.execute_reply": "2024-11-24T08:43:03.279049Z",
     "shell.execute_reply.started": "2024-11-24T08:43:03.163120Z"
    },
    "papermill": {
     "duration": 0.199165,
     "end_time": "2024-11-23T14:35:27.323500",
     "exception": false,
     "start_time": "2024-11-23T14:35:27.124335",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3078648/432419542.py:3: LangChainDeprecationWarning: The class `Ollama` was deprecated in LangChain 0.3.1 and will be removed in 1.0.0. An updated version of the class exists in the :class:`~langchain-ollama package and should be used instead. To use it run `pip install -U :class:`~langchain-ollama` and import as `from :class:`~langchain_ollama import OllamaLLM``.\n",
      "  ollama_llm = Ollama(\n"
     ]
    }
   ],
   "source": [
    "from langchain.llms import Ollama\n",
    "\n",
    "ollama_llm = Ollama(\n",
    "    model='qwen2:7b-instruct-32k',\n",
    "    # model='qwen2:7b-instruct',\n",
    "    base_url='http://localhost:11434',\n",
    "    top_k=1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "166392d8-f801-4372-b8ad-3e79aef0b350",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:43:06.086104Z",
     "iopub.status.busy": "2024-11-24T08:43:06.085318Z",
     "iopub.status.idle": "2024-11-24T08:43:06.093284Z",
     "shell.execute_reply": "2024-11-24T08:43:06.092841Z",
     "shell.execute_reply.started": "2024-11-24T08:43:06.086033Z"
    },
    "papermill": {
     "duration": 0.141864,
     "end_time": "2024-11-23T14:35:27.564409",
     "exception": false,
     "start_time": "2024-11-23T14:35:27.422545",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "prediction_df = qa_df[qa_df['dataset'] == 'test'][['uuid', 'question', 'qa_type', 'answer']].rename(columns={'answer': 'ref_answer'})\n",
    "\n",
    "def predict(llm, prediction_df, n_chunks):\n",
    "    prediction_df = prediction_df.copy()\n",
    "    answer_dict = {}\n",
    "\n",
    "    for idx, row in tqdm(prediction_df.iterrows(), total=len(prediction_df)):\n",
    "        uuid = row['uuid']\n",
    "        question = row['question']\n",
    "        answer, chunks = rag(llm, question, n_chunks=n_chunks)\n",
    "        assert len(chunks) <= n_chunks\n",
    "        answer_dict[question] = {\n",
    "            'uuid': uuid,\n",
    "            'ref_answer': row['ref_answer'],\n",
    "            'gen_answer': answer,\n",
    "            'chunks': chunks\n",
    "        }\n",
    "        \n",
    "    prediction_df.loc[:, 'gen_answer'] = prediction_df['question'].apply(lambda q: answer_dict[q]['gen_answer'])\n",
    "    prediction_df.loc[:, 'chunks'] = prediction_df['question'].apply(lambda q: answer_dict[q]['chunks'])\n",
    "\n",
    "    return prediction_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "ca46d5f1-e698-457d-abb6-92d83cd59c66",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T08:45:06.327331Z",
     "iopub.status.busy": "2024-11-24T08:45:06.327066Z",
     "iopub.status.idle": "2024-11-24T09:02:54.279108Z",
     "shell.execute_reply": "2024-11-24T09:02:54.278634Z",
     "shell.execute_reply.started": "2024-11-24T08:45:06.327304Z"
    },
    "papermill": {
     "duration": 514.92352,
     "end_time": "2024-11-23T14:44:02.805529",
     "exception": false,
     "start_time": "2024-11-23T14:35:27.882009",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a8ab76ed25e5413398af42ee27bc15c1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pred_df = predict(ollama_llm, prediction_df, n_chunks=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "4a1c867e-d59c-454d-93f5-2c8e23b666de",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T09:05:51.429013Z",
     "iopub.status.busy": "2024-11-24T09:05:51.428218Z",
     "iopub.status.idle": "2024-11-24T09:05:51.605687Z",
     "shell.execute_reply": "2024-11-24T09:05:51.603472Z",
     "shell.execute_reply.started": "2024-11-24T09:05:51.428940Z"
    },
    "papermill": {
     "duration": 0.325117,
     "end_time": "2024-11-23T14:44:03.281713",
     "exception": false,
     "start_time": "2024-11-23T14:44:02.956596",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saving to ../experiments/retrieval_v13_contextual_embeddings_deepseek/pred_df.pkl\n"
     ]
    }
   ],
   "source": [
    "save_path = os.path.join(expr_dir, 'pred_df.pkl')\n",
    "\n",
    "if not os.path.exists(save_path):\n",
    "    print(f'saving to {save_path}')\n",
    "    pickle.dump(pred_df, open(save_path, 'wb'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d79e974-089f-4c08-ba5e-804f6542e06a",
   "metadata": {
    "papermill": {
     "duration": 0.14423,
     "end_time": "2024-11-23T14:44:03.513124",
     "exception": false,
     "start_time": "2024-11-23T14:44:03.368894",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "# 评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "217568fe-c0e4-49eb-9a7c-9fdfbc033d8a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T09:06:09.585739Z",
     "iopub.status.busy": "2024-11-24T09:06:09.585317Z",
     "iopub.status.idle": "2024-11-24T09:06:09.601290Z",
     "shell.execute_reply": "2024-11-24T09:06:09.600860Z",
     "shell.execute_reply.started": "2024-11-24T09:06:09.585702Z"
    },
    "papermill": {
     "duration": 0.369729,
     "end_time": "2024-11-23T14:44:04.017198",
     "exception": false,
     "start_time": "2024-11-23T14:44:03.647469",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain_openai import ChatOpenAI\n",
    "import time\n",
    "\n",
    "judge_llm = ChatOpenAI(\n",
    "    api_key=os.environ['LLM_API_KEY'],\n",
    "    base_url=os.environ['LLM_BASE_URL'],\n",
    "    model_name='qwen2-72b-instruct',\n",
    "    temperature=0\n",
    ")\n",
    "\n",
    "def evaluate(prediction_df):\n",
    "    \"\"\"\n",
    "    对预测结果进行打分\n",
    "    :param prediction_df: 预测结果，需要包含问题，参考答案，生成的答案，列名分别为question, ref_answer, gen_answer\n",
    "    :return 打分模型原始返回结果\n",
    "    \"\"\"\n",
    "    prompt_tmpl = \"\"\"\n",
    "你是一个经济学博士，现在我有一系列问题，有一个助手已经对这些问题进行了回答，你需要参照参考答案，评价这个助手的回答是否正确，仅回复“是”或“否”即可，不要带其他描述性内容或无关信息。\n",
    "问题：\n",
    "<question>\n",
    "{{question}}\n",
    "</question>\n",
    "\n",
    "参考答案：\n",
    "<ref_answer>\n",
    "{{ref_answer}}\n",
    "</ref_answer>\n",
    "\n",
    "助手回答：\n",
    "<gen_answer>\n",
    "{{gen_answer}}\n",
    "</gen_answer>\n",
    "请评价：\n",
    "    \"\"\"\n",
    "    results = []\n",
    "\n",
    "    for _, row in prediction_df.iterrows():\n",
    "        question = row['question']\n",
    "        ref_answer = row['ref_answer']\n",
    "        gen_answer = row['gen_answer']\n",
    "\n",
    "        prompt = prompt_tmpl.replace('{{question}}', question).replace('{{ref_answer}}', str(ref_answer)).replace('{{gen_answer}}', gen_answer).strip()\n",
    "        \n",
    "        retry_count = 3\n",
    "        result = ''\n",
    "        \n",
    "        while retry_count > 0:\n",
    "            try:\n",
    "                result = judge_llm.invoke(prompt).content\n",
    "                break\n",
    "            except Exception as e:\n",
    "                retry_count -= 1\n",
    "                sleeping_seconds = 2 ** (4 - retry_count)\n",
    "                print(f\"query={question}, error={e}, sleeping={sleeping_seconds}, remaining retry count={retry_count}\")\n",
    "                \n",
    "                time.sleep(sleeping_seconds)\n",
    "        \n",
    "        results.append(result)\n",
    "\n",
    "        time.sleep(1)\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "71db81af-b8f9-47ba-958b-761896516605",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T09:06:09.825684Z",
     "iopub.status.busy": "2024-11-24T09:06:09.825512Z",
     "iopub.status.idle": "2024-11-24T09:08:39.973367Z",
     "shell.execute_reply": "2024-11-24T09:08:39.971103Z",
     "shell.execute_reply.started": "2024-11-24T09:06:09.825671Z"
    },
    "papermill": {
     "duration": 150.566109,
     "end_time": "2024-11-23T14:46:34.714324",
     "exception": false,
     "start_time": "2024-11-23T14:44:04.148215",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "pred_df['raw_score'] = evaluate(pred_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "7da1b98e-99aa-4e11-9297-91eac1c62493",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T09:08:39.978335Z",
     "iopub.status.busy": "2024-11-24T09:08:39.977573Z",
     "iopub.status.idle": "2024-11-24T09:08:39.984078Z",
     "shell.execute_reply": "2024-11-24T09:08:39.983669Z",
     "shell.execute_reply.started": "2024-11-24T09:08:39.978235Z"
    },
    "papermill": {
     "duration": 0.138037,
     "end_time": "2024-11-23T14:46:35.040595",
     "exception": false,
     "start_time": "2024-11-23T14:46:34.902558",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['是', '否'], dtype=object)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_df['raw_score'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "2c99c078-d294-40b8-b57b-31cfd7349c3e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T09:08:39.984732Z",
     "iopub.status.busy": "2024-11-24T09:08:39.984584Z",
     "iopub.status.idle": "2024-11-24T09:08:40.003530Z",
     "shell.execute_reply": "2024-11-24T09:08:40.001078Z",
     "shell.execute_reply.started": "2024-11-24T09:08:39.984721Z"
    },
    "papermill": {
     "duration": 0.107466,
     "end_time": "2024-11-23T14:46:35.243603",
     "exception": false,
     "start_time": "2024-11-23T14:46:35.136137",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "pred_df['score'] = (pred_df['raw_score'] == '是').astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "423897f2-786e-415b-a613-55a4359faf76",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T09:08:40.009817Z",
     "iopub.status.busy": "2024-11-24T09:08:40.008081Z",
     "iopub.status.idle": "2024-11-24T09:08:40.022740Z",
     "shell.execute_reply": "2024-11-24T09:08:40.020729Z",
     "shell.execute_reply.started": "2024-11-24T09:08:40.009734Z"
    },
    "papermill": {
     "duration": 0.094328,
     "end_time": "2024-11-23T14:46:35.431162",
     "exception": false,
     "start_time": "2024-11-23T14:46:35.336834",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.74"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_df['score'].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "79325429-9cf1-4e2c-95ac-cb0c1a3b6156",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-24T09:08:40.026804Z",
     "iopub.status.busy": "2024-11-24T09:08:40.026036Z",
     "iopub.status.idle": "2024-11-24T09:08:40.261879Z",
     "shell.execute_reply": "2024-11-24T09:08:40.261377Z",
     "shell.execute_reply.started": "2024-11-24T09:08:40.026730Z"
    },
    "papermill": {
     "duration": 0.289336,
     "end_time": "2024-11-23T14:46:35.804651",
     "exception": false,
     "start_time": "2024-11-23T14:46:35.515315",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saving to ../experiments/retrieval_v13_contextual_embeddings_deepseek/eval_df.pkl\n"
     ]
    }
   ],
   "source": [
    "save_path = os.path.join(expr_dir, 'eval_df.pkl')\n",
    "\n",
    "if not os.path.exists(save_path):\n",
    "    print(f'saving to {save_path}')\n",
    "    pickle.dump(pred_df, open(save_path, 'wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88fc7227-9c21-48da-b179-5070406eb113",
   "metadata": {
    "papermill": {
     "duration": 0.088622,
     "end_time": "2024-11-23T14:46:36.016801",
     "exception": false,
     "start_time": "2024-11-23T14:46:35.928179",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 1058.563616,
   "end_time": "2024-11-23T14:46:37.625874",
   "environment_variables": {},
   "exception": null,
   "input_path": "13_contextual_embeddings.ipynb",
   "output_path": "run_13_contextual_embeddings.ipynb",
   "parameters": {},
   "start_time": "2024-11-23T14:28:59.062258",
   "version": "2.6.0"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {
     "0cd8c168767249f2a5fa412173f6e751": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "success",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_5ce1d1d9d86c40d9839877ff95734491",
       "max": 100,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_231702cf4d79477f9d5548665a1b18fe",
       "tabbable": null,
       "tooltip": null,
       "value": 100
      }
     },
     "2133bb8d85d34b8db112b4408ad60320": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "231702cf4d79477f9d5548665a1b18fe": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "23b1ad9c0f9c46c888da66e85c90eb84": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "24e6eadc3dc940ecabf30dd1a3c6d1f3": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "success",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_fa4bddf2c33241b5bf918054518f128f",
       "max": 52,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_edc33e82be8f41eba6a18a0ef074ab7a",
       "tabbable": null,
       "tooltip": null,
       "value": 52
      }
     },
     "2f60367b1c8941e2bf71661c33969ae8": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "3865f25c78aa46f29a25d807205281c3": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "3d0b06deaa654b989eece8cde06fa0f8": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "3f8ceda83287475b97608e42f5f6782f": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "4881e496f1c84fe29ce9ebebaddfb3c2": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_bd096d5d219a467786a85cfe1613fedd",
        "IPY_MODEL_24e6eadc3dc940ecabf30dd1a3c6d1f3",
        "IPY_MODEL_bc2b8104b4244d8cacedeb95e800d91c"
       ],
       "layout": "IPY_MODEL_6b9a8e43c1c342dba500a14e7149b600",
       "tabbable": null,
       "tooltip": null
      }
     },
     "5ce1d1d9d86c40d9839877ff95734491": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "5ddb08be5cc64c9ab40a1d62a21763a5": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_86283159049d48b1adcfb2de2d404d4d",
       "placeholder": "​",
       "style": "IPY_MODEL_2133bb8d85d34b8db112b4408ad60320",
       "tabbable": null,
       "tooltip": null,
       "value": " 100/100 [08:34&lt;00:00, 10.01s/it]"
      }
     },
     "5ef9d83ccad1471f85335900a24a8553": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "6b9a8e43c1c342dba500a14e7149b600": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "816a079a8c804fbfa9b9a74f941abea8": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_bcc69ec5db1b4aab977807284c9290e7",
        "IPY_MODEL_0cd8c168767249f2a5fa412173f6e751",
        "IPY_MODEL_5ddb08be5cc64c9ab40a1d62a21763a5"
       ],
       "layout": "IPY_MODEL_d1178c6858284f788a80b5f2a14fd0b7",
       "tabbable": null,
       "tooltip": null
      }
     },
     "86283159049d48b1adcfb2de2d404d4d": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "8ff8262c56604119883f4a5f13bb74ab": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_5ef9d83ccad1471f85335900a24a8553",
       "placeholder": "​",
       "style": "IPY_MODEL_e89e77133c344fc48c1d62f5a607ec93",
       "tabbable": null,
       "tooltip": null,
       "value": " 8/8 [00:18&lt;00:00,  2.27s/it]"
      }
     },
     "9189a076554543aaa6f5ee04e40dbe1b": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "988e6697a2af486fadeaf0b84347b565": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_e1aae4c55cb64f379e74f15357275628",
        "IPY_MODEL_fd9e23198ca1489a9773fda3510bf857",
        "IPY_MODEL_8ff8262c56604119883f4a5f13bb74ab"
       ],
       "layout": "IPY_MODEL_d2ee15001d2244529f7e47d3333c0f8e",
       "tabbable": null,
       "tooltip": null
      }
     },
     "9fc7d91f94a94933bde5ba80e64587de": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "a7d240a289084bdfba4724c0efd5ab07": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "bc2b8104b4244d8cacedeb95e800d91c": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_2f60367b1c8941e2bf71661c33969ae8",
       "placeholder": "​",
       "style": "IPY_MODEL_9fc7d91f94a94933bde5ba80e64587de",
       "tabbable": null,
       "tooltip": null,
       "value": " 52/52 [04:26&lt;00:00,  4.22s/it]"
      }
     },
     "bcc69ec5db1b4aab977807284c9290e7": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_3f8ceda83287475b97608e42f5f6782f",
       "placeholder": "​",
       "style": "IPY_MODEL_3d0b06deaa654b989eece8cde06fa0f8",
       "tabbable": null,
       "tooltip": null,
       "value": "100%"
      }
     },
     "bd096d5d219a467786a85cfe1613fedd": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_3865f25c78aa46f29a25d807205281c3",
       "placeholder": "​",
       "style": "IPY_MODEL_9189a076554543aaa6f5ee04e40dbe1b",
       "tabbable": null,
       "tooltip": null,
       "value": "100%"
      }
     },
     "cc3ed8dc4a5c43aca7b62d904865b2fa": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "cf68b6fe24964ce792aa63827489cb97": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "d1178c6858284f788a80b5f2a14fd0b7": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "d2ee15001d2244529f7e47d3333c0f8e": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "e1aae4c55cb64f379e74f15357275628": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_23b1ad9c0f9c46c888da66e85c90eb84",
       "placeholder": "​",
       "style": "IPY_MODEL_cf68b6fe24964ce792aa63827489cb97",
       "tabbable": null,
       "tooltip": null,
       "value": "100%"
      }
     },
     "e89e77133c344fc48c1d62f5a607ec93": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "edc33e82be8f41eba6a18a0ef074ab7a": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "fa4bddf2c33241b5bf918054518f128f": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "fd9e23198ca1489a9773fda3510bf857": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "success",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_cc3ed8dc4a5c43aca7b62d904865b2fa",
       "max": 8,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_a7d240a289084bdfba4724c0efd5ab07",
       "tabbable": null,
       "tooltip": null,
       "value": 8
      }
     }
    },
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
